# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Train EfficientNets on ImageNet (TPU training script adapted for NPU)."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time
import argparse
from absl import app
from absl import flags
from absl import logging
import numpy as np
import tensorflow.compat.v1 as tf
import tensorflow.compat.v2 as tf2  # used for summaries only.

import imagenet_input
import model_builder_factory
import utils
# pylint: disable=g-direct-tensorflow-import
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.estimator import estimator
from npu_bridge.npu_init import *
from npu_bridge.estimator.npu.npu_config import NPURunConfig
from npu_bridge.estimator import npu_ops
from npu_bridge.estimator.npu.npu_estimator import NPUEstimator, NPUEstimatorSpec

# pylint: enable=g-direct-tensorflow-import


FLAGS = flags.FLAGS

# Required parameters
# Help strings are built from adjacent literals; each fragment that is followed
# by another one ends with a space so the sentences do not run together.
flags.DEFINE_string(
    'data_dir', default=None,
    help='The directory where the ImageNet input data is stored. '
    'Please see the README.md for the expected data format.')

flags.DEFINE_string(
    'model_dir', default=None,
    help='The directory where the model and training/evaluation summaries are stored.')

flags.DEFINE_string(
    'dump_dir', default="./dump_dir",
    help='The directory where the dump data are stored.')

flags.DEFINE_string(
    'profiling_dir', default="./profiling_dir",
    help='The directory where the profiling data are stored. '
    'You need to create it first and have read and write permissions on it.')

flags.DEFINE_string("obs_dir", None, "obs dir")

flags.DEFINE_string(
    'model_name',
    default='efficientnet-condconv-b0-8e',
    help='The model name among existing configurations.')

# NOTE(review): this flag is separate from --train_batch_size and
# --eval_batch_size defined further down; its help text is unclear.
flags.DEFINE_integer(
    'batch_size', default=128, help='Batch size for me.')


# The switches below are only useful when running in NPU chip mode.
flags.DEFINE_boolean(
    'npu_dump_data', default=False,
    help='dump data for precision or not')
flags.DEFINE_boolean(
    'npu_dump_graph', default=False,
    help='dump graph or not')
flags.DEFINE_boolean(
    'npu_profiling', default=False,
    help='profiling for performance or not')
flags.DEFINE_boolean(
    'npu_auto_tune', default=False,
    help='auto tune or not. And you must set tune_bank_path param.')


# GCS path of the fake ImageNet dataset.
FAKE_DATA_DIR = 'gs://cloud-tpu-test-datasets/fake_imagenet'

flags.DEFINE_bool(
    'use_tpu', default=False,
    help=('Use TPU to execute the model for training and evaluation. If'
          ' --use_tpu=false, will use whatever devices are available to'
          ' TensorFlow by default (e.g. CPU and GPU)'))

flags.DEFINE_string('tpu_job_name', None, help=('Name of worker binary.'))

# Cloud TPU Cluster Resolvers
flags.DEFINE_string(
    'tpu', default=None,
    help='The Cloud TPU to use for training. This should be either the name '
    'used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 url.')

flags.DEFINE_string(
    'gcp_project', default=None,
    help='Project name for the Cloud TPU-enabled project. If not specified, we '
    'will attempt to automatically detect the GCE project from metadata.')

# The help text used to repeat the --gcp_project wording ("detect the GCE
# project"); this flag is about the zone.
flags.DEFINE_string(
    'tpu_zone', default=None,
    help='GCE zone where the Cloud TPU is located in. If not specified, we '
    'will attempt to automatically detect the GCE zone from metadata.')

# Model specific flags

# Validation split, evaluation naming, checkpoint archiving and run mode.
flags.DEFINE_integer(
    'holdout_shards', None,
    'Number of holdout shards for validation. Recommended 20.')

flags.DEFINE_string('eval_name', None, 'Evaluation name.')

flags.DEFINE_bool('archive_ckpt', True, 'If true, archive the best ckpt.')

flags.DEFINE_string(
    'mode', 'train_and_eval',
    'One of {"train_and_eval", "train", "eval"}.')

# Each help fragment ends with a space so the concatenated sentences stay
# readable in --help output.
flags.DEFINE_string(
    'augment_name', default=None,
    help='`string` that is the name of the augmentation method '
         'to apply to the image. `autoaugment` if AutoAugment is to be used or '
         '`randaugment` if RandAugment is to be used. If the value is `None` no '
         'augmentation method will be applied. See autoaugment.py for '
         'more details.')


flags.DEFINE_integer(
    'randaug_num_layers', default=2,
    help='If RandAug is used, what should the number of layers be. '
         'See autoaugment.py for detailed description.')

flags.DEFINE_integer(
    'randaug_magnitude', default=10,
    help='If RandAug is used, what should the magnitude be. '
         'See autoaugment.py for detailed description.')


# The help no longer quotes a default step count: the old text named 218949
# steps while the actual default below is 250000.
flags.DEFINE_integer(
    'train_steps', default=250000,
    help=('The number of steps to use for training. This flag'
          ' should be adjusted according to the --train_batch_size flag.'))

# Image size, batch sizes, dataset sizes and evaluation / loop cadence.
flags.DEFINE_integer(
    'input_image_size', None,
    'Input image size: it depends on specific model name.')

flags.DEFINE_integer('train_batch_size', 128, 'Batch size for training.')

flags.DEFINE_integer('eval_batch_size', 8, 'Batch size for evaluation.')

flags.DEFINE_integer(
    'batch_norm_batch_size', 8,
    'Per-group batch size for distributed batch normalization.')

flags.DEFINE_integer('num_train_images', 1281167, 'Size of training data set.')

flags.DEFINE_integer('num_eval_images', 50000, 'Size of evaluation data set.')

flags.DEFINE_integer(
    'steps_per_eval', 6255,
    'Controls how often evaluation is performed. Since evaluation is'
    ' fairly expensive, it is advised to evaluate as infrequently as'
    ' possible (i.e. up to --train_steps, which evaluates the model only'
    ' after finishing the entire training regime).')

flags.DEFINE_integer(
    'eval_timeout', None,
    'Maximum seconds between checkpoints before evaluation terminates.')

flags.DEFINE_bool(
    'skip_host_call', False,
    'Skip the host_call which is executed every training step. This is'
    ' generally used for generating training summaries (train loss,'
    ' learning rate, etc...). When --skip_host_call=false, there could'
    ' be a performance drop if host_call function is slow and cannot'
    ' keep up with the TPU-side computation.')

flags.DEFINE_integer(
    'iterations_per_loop', 1251,
    'Number of steps to run on TPU before outfeeding metrics to the CPU.'
    ' If the number of iterations in the loop would exceed the number of'
    ' train steps, the loop will exit before reaching'
    ' --iterations_per_loop. The larger this value is, the higher the'
    ' utilization on the TPU.')

flags.DEFINE_integer(
    'num_parallel_calls', 64,
    'Number of parallel threads in CPU for the input pipeline')

# Cloud Bigtable data source selection.
flags.DEFINE_string(
    'bigtable_project', default=None,
    help='The Cloud Bigtable project.  If None, --gcp_project will be used.')
flags.DEFINE_string(
    'bigtable_instance', default=None,
    help='The Cloud Bigtable instance to load data from.')
flags.DEFINE_string(
    'bigtable_table', default='imagenet',
    help='The Cloud Bigtable table to load data from.')
flags.DEFINE_string(
    'bigtable_train_prefix', default='train_',
    help='The prefix identifying training rows.')
flags.DEFINE_string(
    'bigtable_eval_prefix', default='validation_',
    help='The prefix identifying evaluation rows.')
flags.DEFINE_string(
    'bigtable_column_family', default='tfexample',
    help='The column family storing TFExamples.')
flags.DEFINE_string(
    'bigtable_column_qualifier', default='example',
    help='The column name storing TFExamples.')

# Tensor layout, class count and batch-normalization overrides.
flags.DEFINE_string(
    'data_format', default='channels_last',
    help=('A flag to override the data format used in the model. The value'
          ' is either channels_first or channels_last. To run the network on'
          ' CPU or TPU, channels_last should be used. For GPU, channels_first'
          ' will improve performance.'))
flags.DEFINE_integer(
    'num_label_classes', default=1000, help='Number of classes, at least 2')

# Both batch-norm flags default to None, which model_fn treats as
# "do not override".
flags.DEFINE_float(
    'batch_norm_momentum',
    default=None,
    help=('Batch normalization layer momentum of moving average to override.'))
flags.DEFINE_float(
    'batch_norm_epsilon',
    default=None,
    help=('Batch normalization layer epsilon to override.'))

# Input transposition, activation precision and SavedModel export options.
flags.DEFINE_bool(
    'transpose_input', True, 'Use TPU double transpose optimization')

flags.DEFINE_bool(
    'use_bfloat16', False,
    'Whether to use bfloat16 as activation for training.')

flags.DEFINE_string(
    'export_dir', None,
    'The directory where the exported SavedModel will be stored.')
flags.DEFINE_bool(
    'export_to_tpu', False,
    'Whether to export additional metagraph with "serve, tpu" tags'
    ' in addition to "serve" only metagraph.')

# Learning-rate and regularization hyper-parameters.
# model_fn scales base_learning_rate by train_batch_size / 256.
flags.DEFINE_float(
    'base_learning_rate',
    default=0.016,
    help=('Base learning rate when train batch size is 256.'))

flags.DEFINE_float('lr_decay_epoch', default=2.4, help='LR decay epoch.')

# model_fn only applies the moving average when this value is > 0.
flags.DEFINE_float(
    'moving_average_decay', default=0.9999,
    help=('Moving average decay rate.'))

flags.DEFINE_float(
    'weight_decay', default=1e-5,
    help=('Weight decay coefficient for l2 regularization.'))

flags.DEFINE_float(
    'label_smoothing', default=0.1,
    help=('Label smoothing parameter used in the softmax_cross_entropy'))

flags.DEFINE_float(
    'dropout_rate', default=None,
    help=('Dropout rate for the final output layer.'))

# NOTE(review): the flag is named survival_prob but the help calls it a drop
# connect rate; confirm which meaning the model builder expects.
flags.DEFINE_float(
    'survival_prob', default=None,
    help=('Drop connect rate for the network.'))

flags.DEFINE_float(
    'mixup_alpha',
    default=0.0,
    help=('Alpha parameter for mixup regularization, 0.0 to disable.'))

# Logging cadence, input caching, model scaling, checkpointing and optimizer
# selection.
flags.DEFINE_integer(
    'log_step_count_steps', default=64,
    help='The number of steps at '
         'which the global step information is logged.')

flags.DEFINE_bool('use_cache', False, 'Enable cache for training input.')

flags.DEFINE_float(
    'depth_coefficient', None,
    'Depth coefficient for scaling number of layers.')

flags.DEFINE_float(
    'width_coefficient', None,
    'Width coefficient for scaling channel size.')

flags.DEFINE_bool('use_async_checkpointing', True, 'Enable async checkpoint')

flags.DEFINE_string(
    'optimizer', 'rmsprop',
    'The optimizer to use. Can be either rmsprop, sgd, momentum, or lars.')

flags.DEFINE_string('lr_schedule', 'exponential', 'learning rate schedule')

flags.DEFINE_float('lr_decay_factor', 0.97, 'Learning rate decay factor.')

flags.DEFINE_float('lr_warmup_epochs', 5, 'warmup epochs for learning rate')

flags.DEFINE_float(
    'lars_weight_decay', 0.00001, 'Weight decay for LARS optimizer.')

flags.DEFINE_float('lars_epsilon', 0.0, 'Epsilon for LARS optimizer.')

flags.DEFINE_integer('num_replicas', 32, 'Number of TPU replicas.')


def model_fn(features, labels, mode, params):
    """Build the model graph for one mode and return an `NPUEstimatorSpec`.

    Args:
      features: `Tensor` of batched images, or a `dict` whose `'feature'`
          entry holds that tensor.
      labels: `Tensor` of one hot labels for the data samples
      mode: one of `tf.estimator.ModeKeys.{TRAIN,EVAL,PREDICT}`
      params: `dict` of parameters passed to the model by the estimator. This
          function reads `params['use_bfloat16']` and `params['batch_size']`,
          plus `params['steps_per_epoch']` when training.

    Returns:
      An `NPUEstimatorSpec` for the model. In PREDICT mode it carries only the
      predictions and export outputs. Otherwise it carries the loss, the train
      op (`None` when not training), the eval metric ops (`None` when not
      evaluating) and a scaffold holding the saver.
    """
    # Inputs may arrive wrapped in a dict; unwrap to the image tensor.
    if isinstance(features, dict):
        features = features['feature']

    # In most cases, the default data format NCHW instead of NHWC should be
    # used for a significant performance boost on GPU. NHWC should be used
    # only if the network needs to be run on CPU since the pooling operations
    # are only supported on NHWC. TPU uses XLA compiler to figure out best layout.
    # stats_shape is the broadcast shape used for the per-channel mean/stddev
    # constants in normalize_features below.
    if FLAGS.data_format == 'channels_first':
        assert not FLAGS.transpose_input    # channels_first only for GPU
        features = tf.transpose(features, [0, 3, 1, 2])
        stats_shape = [3, 1, 1]
    else:
        stats_shape = [1, 1, 3]

    # Fall back to the size registered for the model when the flag is unset.
    input_image_size = FLAGS.input_image_size
    if not input_image_size:
        input_image_size = model_builder_factory.get_model_input_size(
            FLAGS.model_name)

    # Undo the input-side transpose for TRAIN/EVAL; PREDICT inputs are left
    # untouched.
    if FLAGS.transpose_input and mode != tf.estimator.ModeKeys.PREDICT:
        features = tf.reshape(features,
                              [input_image_size, input_image_size, 3, -1])
        features = tf.transpose(features, [3, 0, 1, 2])  # HWCN to NHWC

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    has_moving_average_decay = (FLAGS.moving_average_decay > 0)
    # This is essential, if using a keras-derived model.
    tf.keras.backend.set_learning_phase(is_training)
    logging.info('Using open-source implementation.')
    # Collect model overrides from flags. The first four are forwarded whenever
    # they are not None; the remaining ones use truthiness, so a value of 0 or
    # an empty string is treated as "not set".
    override_params = {}
    if FLAGS.batch_norm_momentum is not None:
        override_params['batch_norm_momentum'] = FLAGS.batch_norm_momentum
    if FLAGS.batch_norm_epsilon is not None:
        override_params['batch_norm_epsilon'] = FLAGS.batch_norm_epsilon
    if FLAGS.dropout_rate is not None:
        override_params['dropout_rate'] = FLAGS.dropout_rate
    if FLAGS.survival_prob is not None:
        override_params['survival_prob'] = FLAGS.survival_prob
    if FLAGS.data_format:
        override_params['data_format'] = FLAGS.data_format
    if FLAGS.num_label_classes:
        override_params['num_classes'] = FLAGS.num_label_classes
    if FLAGS.depth_coefficient:
        override_params['depth_coefficient'] = FLAGS.depth_coefficient
    if FLAGS.width_coefficient:
        override_params['width_coefficient'] = FLAGS.width_coefficient

    def normalize_features(features, mean_rgb, stddev_rgb):
        """Subtract the per-channel mean and divide by the per-channel stddev.

        The constants are shaped with `stats_shape` from the enclosing scope so
        that they broadcast along the channel axis of the chosen data format.
        """
        features -= tf.constant(mean_rgb, shape=stats_shape,
                                dtype=features.dtype)
        features /= tf.constant(stddev_rgb,
                                shape=stats_shape, dtype=features.dtype)
        return features

    def build_model():
        """Build the model named by --model_name and return its logits.

        Reads `features`, `is_training` and `override_params` from the
        enclosing scope; the second value returned by the model builder is
        discarded.
        """
        model_builder = model_builder_factory.get_model_builder(
            FLAGS.model_name)
        normalized_features = normalize_features(features, model_builder.MEAN_RGB,
                                                 model_builder.STDDEV_RGB)
        logits, _ = model_builder.build_model(
            normalized_features,
            model_name=FLAGS.model_name,
            training=is_training,
            override_params=override_params,
            model_dir=FLAGS.model_dir)
        return logits

    # With bfloat16 the model is built inside the bfloat16 scope and the logits
    # are cast back to float32 before the loss and metrics are computed.
    if params['use_bfloat16']:
        with tf.tpu.bfloat16_scope():
            logits = tf.cast(build_model(), tf.float32)
    else:
        logits = build_model()

    # PREDICT returns early: no loss, optimizer or metrics are built.
    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            'classes': tf.argmax(logits, axis=1),
            'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
        }
        return NPUEstimatorSpec(
            mode=mode,
            predictions=predictions,
            export_outputs={
                'classify': tf.estimator.export.PredictOutput(predictions)
            })

    # If necessary, in the model_fn, use params['batch_size'] instead the batch
    # size flags (--train_batch_size or --eval_batch_size).
    # The value is read here but not used anywhere below.
    batch_size = params['batch_size']   # pylint: disable=unused-variable

    # Calculate loss, which includes softmax cross entropy and L2 regularization.
    cross_entropy = tf.losses.softmax_cross_entropy(
        logits=logits,
        onehot_labels=labels,
        label_smoothing=FLAGS.label_smoothing)

    # Add weight decay to the loss for non-batch-normalization variables.
    # Variables are excluded by name: anything containing 'batch_normalization'.
    loss = cross_entropy + FLAGS.weight_decay * tf.add_n(
        [tf.nn.l2_loss(v) for v in tf.trainable_variables()
         if 'batch_normalization' not in v.name])

    global_step = tf.train.get_global_step()
    # The EMA object is created for both TRAIN and EVAL: training applies it,
    # evaluation uses it to map variables to their averaged values.
    if has_moving_average_decay:
        ema = tf.train.ExponentialMovingAverage(
            decay=FLAGS.moving_average_decay, num_updates=global_step)
        ema_vars = utils.get_ema_vars()

    host_call = None
    restore_vars_dict = None
    if is_training:
        # Compute the current epoch and associated learning rate from global_step.
        current_epoch = (
            tf.cast(global_step, tf.float32) / params['steps_per_epoch'])

        # Linear scaling of the base rate, which is specified for batch size 256.
        scaled_lr = FLAGS.base_learning_rate * (FLAGS.train_batch_size / 256.0)
        logging.info('base_learning_rate = %f', FLAGS.base_learning_rate)
        learning_rate = utils.build_learning_rate(
            scaled_lr,
            global_step,
            params['steps_per_epoch'],
            decay_epochs=FLAGS.lr_decay_epoch,
            warmup_epochs=FLAGS.lr_warmup_epochs,
            decay_factor=FLAGS.lr_decay_factor,
            lr_decay_type=FLAGS.lr_schedule,
            total_steps=FLAGS.train_steps)
        optimizer = utils.build_optimizer(
            learning_rate,
            optimizer_name=FLAGS.optimizer,
            lars_weight_decay=FLAGS.lars_weight_decay,
            lars_epsilon=FLAGS.lars_epsilon)
        if FLAGS.use_tpu:
            # When using TPU, wrap the optimizer with CrossShardOptimizer which
            # handles synchronization details between different TPU cores. To the
            # user, this should look like regular synchronous training.
            optimizer = tf.tpu.CrossShardOptimizer(optimizer)

        # Batch normalization requires UPDATE_OPS to be added as a dependency to
        # the train operation.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(loss, global_step)

        # Chain the EMA update after the optimizer step; the EMA op becomes the
        # train op that is returned.
        if has_moving_average_decay:
            with tf.control_dependencies([train_op]):
                train_op = ema.apply(ema_vars)

        if not FLAGS.skip_host_call:
            def host_call_fn(gs, lr, ce):
                """Training host call. Creates scalar summaries for training metrics.

                This function is executed on the CPU and should not directly reference
                any Tensors in the rest of the `model_fn`. To pass Tensors from the
                model to the `metric_fn`, provide as part of the `host_call`. See
                https://www.tensorflow.org/api_docs/python/tf/estimator/tpu/TPUEstimatorSpec
                for more information.

                Arguments should match the list of `Tensor` objects passed as the second
                element in the tuple passed to `host_call`.

                Args:
                  gs: `Tensor` with shape `[batch]` for the global_step
                  lr: `Tensor` with shape `[batch]` for the learning_rate.
                  ce: `Tensor` with shape `[batch]` for the current_epoch.

                Returns:
                  List of summary ops to run on the CPU host.
                """
                gs = gs[0]
                # Host call fns are executed FLAGS.iterations_per_loop times after one
                # TPU loop is finished, setting max_queue value to the same as number of
                # iterations will make the summary writer only flush the data to storage
                # once per loop.
                with tf2.summary.create_file_writer(
                        FLAGS.model_dir, max_queue=FLAGS.iterations_per_loop).as_default():
                    with tf2.summary.record_if(True):
                        tf2.summary.scalar('learning_rate', lr[0], step=gs)
                        tf2.summary.scalar('current_epoch', ce[0], step=gs)

                        return tf.summary.all_v2_summary_ops()

            # To log the current learning rate and epoch for Tensorboard, the
            # summary op needs to be run on the host CPU via host_call. host_call
            # expects [batch_size, ...] Tensors, thus reshape to introduce a batch
            # dimension. These Tensors are implicitly concatenated to
            # [params['batch_size']].
            gs_t = tf.reshape(global_step, [1])
            lr_t = tf.reshape(learning_rate, [1])
            ce_t = tf.reshape(current_epoch, [1])

            # NOTE(review): host_call is assembled here but is not passed to the
            # NPUEstimatorSpec returned below, so it is not attached to the spec.
            host_call = (host_call_fn, [gs_t, lr_t, ce_t])

    else:
        train_op = None
        if has_moving_average_decay:
            # Load moving average variables for eval.
            restore_vars_dict = ema.variables_to_restore(ema_vars)

    eval_metrics = None
    if mode == tf.estimator.ModeKeys.EVAL:
        def metric_fn(labels, logits):
            """Evaluation metric function. Evaluates accuracy.

            Computes top-1 accuracy from the argmax of labels and logits, and
            top-5 accuracy as the mean of `tf.nn.in_top_k` over the batch.

            Args:
              labels: `Tensor` with shape `[batch, num_classes]`.
              logits: `Tensor` with shape `[batch, num_classes]`.

            Returns:
              A dict of the metrics to return from evaluation.
            """
            labels = tf.argmax(labels, axis=1)
            predictions = tf.argmax(logits, axis=1)
            top_1_accuracy = tf.metrics.accuracy(labels, predictions)
            in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)
            top_5_accuracy = tf.metrics.mean(in_top_5)

            return {
                'top_1_accuracy': top_1_accuracy,
                'top_5_accuracy': top_5_accuracy,
            }

        # metric_fn is called directly and its dict is handed to the spec as
        # eval_metric_ops, rather than passing a (metric_fn, tensors) tuple.
        eval_metrics = metric_fn(labels, logits)

    num_params = np.sum([np.prod(v.shape) for v in tf.trainable_variables()])
    logging.info('number of trainable parameters: %d', num_params)

    # restore_vars_dict is only set when not training and EMA is enabled;
    # otherwise the saver is constructed with None.
    saver = tf.train.Saver(restore_vars_dict)

    return NPUEstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metrics,
        scaffold=tf.train.Scaffold(saver=saver))


def _verify_non_empty_string(value, field_name):
    """Validate that a proposed Bigtable field value is a non-empty string.

    Args:
      value:  candidate value for the field.
      field_name:  name of the field (e.g. `project`), used in the error text.

    Returns:
      `value` unchanged when it is a non-empty `str`.

    Raises:
      ValueError:  `value` is not a `str`, or it is the empty string.
    """
    if isinstance(value, str):
        if value:
            return value
        # A string, but an empty one.
        template = 'Bigtable parameter "%s" must be non-empty.'
    else:
        template = 'Bigtable parameter "%s" must be a string.'
    raise ValueError(template % field_name)


def _select_tables_from_flags():
    """Build the training and evaluation Bigtable selections from flags.

    Every Bigtable-related flag is checked with `_verify_non_empty_string`
    before any selection is created; the project falls back to
    `--gcp_project` when `--bigtable_project` is unset.

    Returns:
      [training_selection, evaluation_selection]
    """
    check = _verify_non_empty_string

    project = check(FLAGS.bigtable_project or FLAGS.gcp_project, 'project')
    instance = check(FLAGS.bigtable_instance, 'instance')
    table = check(FLAGS.bigtable_table, 'table')
    # Training prefix first, evaluation prefix second: this fixes the order
    # of the returned list.
    row_prefixes = (
        check(FLAGS.bigtable_train_prefix, 'train_prefix'),
        check(FLAGS.bigtable_eval_prefix, 'eval_prefix'),
    )
    column_family = check(FLAGS.bigtable_column_family, 'column_family')
    column_qualifier = check(FLAGS.bigtable_column_qualifier,
                             'column_qualifier')

    selections = []
    for row_prefix in row_prefixes:
        selections.append(
            imagenet_input.BigtableSelection(
                project=project,
                instance=instance,
                table=table,
                prefix=row_prefix,
                column_family=column_family,
                column_qualifier=column_qualifier))
    return selections


def export(est, export_dir, input_image_size=None):
    """Export the estimator's graph as a SavedModel.

    Args:
      est: estimator instance.
      export_dir: string, exporting directory.
      input_image_size: int, input image size. When falsy, falls back to
        `--input_image_size` and then to the input size registered for
        `--model_name`.

    Raises:
      ValueError: the export directory path is not specified.
    """
    if not export_dir:
        raise ValueError('The export directory path is not specified.')

    # Resolve the image size the same way model_fn() and main() do:
    # --input_image_size defaults to None, so without the last fallback a
    # None size would reach the serving input function.
    input_image_size = (
        input_image_size
        or FLAGS.input_image_size
        or model_builder_factory.get_model_input_size(FLAGS.model_name))

    # Use fixed batch size for condconv.
    is_cond_conv = FLAGS.model_name.startswith('efficientnet-condconv')
    batch_size = 1 if is_cond_conv else None

    logging.info('Starting to export model.')
    if FLAGS.model_name.startswith(
            ('efficientnet-lite', 'efficientnet-edgetpu')):
        # lite or edgetpu use bilinear for easier post-quantization.
        resize_method = tf.image.ResizeMethod.BILINEAR
    else:
        resize_method = None

    image_serving_input_fn = imagenet_input.build_image_serving_input_fn(
        input_image_size, batch_size=batch_size, resize_method=resize_method)
    est.export_saved_model(
        export_dir_base=export_dir,
        serving_input_receiver_fn=image_serving_input_fn)


def main(unused_argv):
    """Configure an NPU estimator and run the job selected by `FLAGS.mode`.

    Steps, in order:
      1. Resolve the input image size and the train/eval image counts
         (taking `FLAGS.holdout_shards` into account).
      2. Build the session config for the "NpuOptimizer" custom optimizer,
         including dump and profiling output directories (created if absent).
      3. Construct the `NPUEstimator` and the train/eval input pipelines.
      4. Run one of:
           - 'eval': evaluate every new checkpoint found in `FLAGS.model_dir`
             until the checkpoint step reaches `FLAGS.train_steps`;
           - 'train': train up to `FLAGS.train_steps`;
           - 'train_and_eval': alternate training for
             `FLAGS.steps_per_eval` steps with an evaluation pass.
      5. Export a SavedModel when `FLAGS.export_dir` is set.

    Args:
      unused_argv: positional command-line arguments passed by `app.run`;
        not used.
    """
    logging.set_verbosity(logging.INFO)

    print("===>>>dataset:{}".format(FLAGS.data_dir))
    print("===>>>result:{}".format(FLAGS.model_dir))
    print("===>>>train_steps:{}".format(FLAGS.train_steps))

    input_image_size = FLAGS.input_image_size
    if not input_image_size:
        input_image_size = model_builder_factory.get_model_input_size(
            FLAGS.model_name)

    if FLAGS.holdout_shards:
        # The held-out portion is holdout_shards / 1024 of the training images
        # (presumably the train set is split into 1024 shards -- confirm).
        holdout_images = int(FLAGS.num_train_images *
                             FLAGS.holdout_shards / 1024.0)
        FLAGS.num_train_images -= holdout_images
        if FLAGS.eval_name and 'test' in FLAGS.eval_name:
            FLAGS.holdout_shards = None  # do not use holdout if eval test set.
        else:
            FLAGS.num_eval_images = holdout_images

    # For imagenet dataset, include background label if number of output classes
    # is 1001
    include_background_label = (FLAGS.num_label_classes == 1001)

    if FLAGS.use_async_checkpointing:
        # Checkpoints are written by the async saver hook added in 'train' mode.
        save_checkpoints_steps = None
    else:
        save_checkpoints_steps = max(100, FLAGS.iterations_per_loop)

    rewriter_off = rewriter_config_pb2.RewriterConfig.OFF
    config = tf.ConfigProto()
    custom_op = config.graph_options.rewrite_options.custom_optimizers.add()
    custom_op.name = "NpuOptimizer"

    # Set the precision_mode
    custom_op.parameter_map['precision_mode'].s = tf.compat.as_bytes(
        'allow_mix_precision')
    custom_op.parameter_map["use_off_line"].b = True

    # Set the dump path. makedirs(exist_ok=True) keeps restarts/resumed runs
    # from failing when the directory is already there and also creates any
    # missing parent directories.
    os.makedirs(FLAGS.dump_dir, exist_ok=True)
    custom_op.parameter_map['dump_path'].s = tf.compat.as_bytes(FLAGS.dump_dir)
    # Set dump debug
    custom_op.parameter_map['enable_dump_debug'].b = True
    custom_op.parameter_map['dump_debug_mode'].s = tf.compat.as_bytes('all')

    # NOTE: profiling is not enabled through the session config
    # ("profiling_mode" / "profiling_options" on custom_op); it is configured
    # through ProfilingConfig on the run config below.

    config.graph_options.rewrite_options.remapping = rewriter_off  # Must be OFF
    config.graph_options.rewrite_options.memory_optimization = rewriter_off  # Must be OFF

    # Same reasoning as for the dump directory: tolerate an existing directory.
    os.makedirs(FLAGS.profiling_dir, exist_ok=True)
    profiling_options = '{"output":"%s","task_trace":"on"}' % FLAGS.profiling_dir
    profiling_config = ProfilingConfig(
        enable_profiling=True, profiling_options=profiling_options)

    runconfig = NPURunConfig(
        profiling_config=profiling_config,
        model_dir=FLAGS.model_dir,
        save_checkpoints_steps=save_checkpoints_steps,
        log_step_count_steps=FLAGS.log_step_count_steps,
        session_config=config)  # pylint: disable=line-too-long

    # Initializes model parameters.
    params = dict(
        steps_per_epoch=FLAGS.num_train_images / FLAGS.train_batch_size,
        use_bfloat16=FLAGS.use_bfloat16,
        batch_size=FLAGS.batch_size)
    est = NPUEstimator(
        model_fn=model_fn,
        config=runconfig,
        model_dir=FLAGS.model_dir,
        params=params)

    if (FLAGS.model_name.startswith('efficientnet-lite') or
            FLAGS.model_name.startswith('efficientnet-edgetpu')):
        # lite or edgetpu use bilinear for easier post-quantization.
        resize_method = tf.image.ResizeMethod.BILINEAR
    else:
        resize_method = None
    # Input pipelines are slightly different (with regards to shuffling and
    # preprocessing) between training and evaluation.

    def build_imagenet_input(is_training):
        """Generate ImageNetInput for training and eval."""
        if FLAGS.bigtable_instance:
            logging.info('Using Bigtable dataset, table %s',
                         FLAGS.bigtable_table)
            select_train, select_eval = _select_tables_from_flags()
            return imagenet_input.ImageNetBigtableInput(
                is_training=is_training,
                use_bfloat16=FLAGS.use_bfloat16,
                transpose_input=FLAGS.transpose_input,
                selection=select_train if is_training else select_eval,
                num_label_classes=FLAGS.num_label_classes,
                include_background_label=include_background_label,
                augment_name=FLAGS.augment_name,
                mixup_alpha=FLAGS.mixup_alpha,
                randaug_num_layers=FLAGS.randaug_num_layers,
                randaug_magnitude=FLAGS.randaug_magnitude,
                resize_method=resize_method)
        else:
            if FLAGS.data_dir == FAKE_DATA_DIR:
                logging.info('Using fake dataset.')
            else:
                logging.info('Using dataset: %s', FLAGS.data_dir)

            return imagenet_input.ImageNetInput(
                is_training=is_training,
                data_dir=FLAGS.data_dir,
                transpose_input=FLAGS.transpose_input,
                cache=FLAGS.use_cache and is_training,
                image_size=input_image_size,
                num_parallel_calls=FLAGS.num_parallel_calls,
                use_bfloat16=FLAGS.use_bfloat16,
                num_label_classes=FLAGS.num_label_classes,
                include_background_label=include_background_label,
                augment_name=FLAGS.augment_name,
                mixup_alpha=FLAGS.mixup_alpha,
                randaug_num_layers=FLAGS.randaug_num_layers,
                randaug_magnitude=FLAGS.randaug_magnitude,
                resize_method=resize_method,
                holdout_shards=FLAGS.holdout_shards)

    imagenet_train = build_imagenet_input(is_training=True)
    imagenet_eval = build_imagenet_input(is_training=False)

    if FLAGS.mode == 'eval':
        eval_steps = FLAGS.num_eval_images // FLAGS.eval_batch_size
        # Run evaluation when there's a new checkpoint
        for ckpt in tf.train.checkpoints_iterator(
                FLAGS.model_dir, timeout=FLAGS.eval_timeout):
            logging.info('Starting to evaluate.')
            try:
                start_timestamp = time.time()  # This time will include compilation time
                eval_results = est.evaluate(
                    input_fn=imagenet_eval.input_fn,
                    steps=eval_steps,
                    checkpoint_path=ckpt,
                    name=FLAGS.eval_name)
                elapsed_time = int(time.time() - start_timestamp)
                logging.info('Eval results: %s. Elapsed seconds: %d',
                             eval_results, elapsed_time)
                if FLAGS.archive_ckpt:
                    utils.archive_ckpt(
                        eval_results, eval_results['top_1_accuracy'], ckpt)

                # Terminate eval job when final checkpoint is reached.
                # The step is parsed from the checkpoint file name
                # ("<prefix>-<step>"); a missing suffix raises IndexError and
                # a non-numeric one raises ValueError -- both mean there is
                # no usable global step, so stop.
                try:
                    current_step = int(os.path.basename(ckpt).split('-')[1])
                except (IndexError, ValueError):
                    logging.info('%s has no global step info: stop!', ckpt)
                    break

                if current_step >= FLAGS.train_steps:
                    logging.info(
                        'Evaluation finished after training step %d', current_step)
                    break

            except tf.errors.NotFoundError:
                # Since the coordinator is on a different job than the TPU worker,
                # sometimes the TPU worker does not finish initializing until long after
                # the CPU job tells it to start evaluating. In this case, the checkpoint
                # file could have been deleted already.
                logging.info(
                    'Checkpoint %s no longer exists, skipping checkpoint', ckpt)
    else:   # FLAGS.mode == 'train' or FLAGS.mode == 'train_and_eval'
        current_step = estimator._load_global_step_from_checkpoint_dir(
            FLAGS.model_dir)  # pylint: disable=protected-access,line-too-long

        logging.info(
            'Training for %d steps (%.2f epochs in total). Current'
            ' step %d.', FLAGS.train_steps,
            FLAGS.train_steps / params['steps_per_epoch'], current_step)

        start_timestamp = time.time()  # This time will include compilation time

        if FLAGS.mode == 'train':
            hooks = []
            if FLAGS.use_async_checkpointing:
                try:
                    from tensorflow.contrib.tpu.python.tpu import async_checkpoint  # pylint: disable=g-import-not-at-top
                except ImportError:
                    logging.exception(
                        'Async checkpointing is not supported in TensorFlow 2.x')
                    raise

                hooks.append(
                    async_checkpoint.AsyncCheckpointSaverHook(
                        checkpoint_dir=FLAGS.model_dir,
                        save_steps=max(100, FLAGS.iterations_per_loop)))
            est.train(
                input_fn=imagenet_train.input_fn,
                max_steps=FLAGS.train_steps,
                hooks=hooks)

        else:
            assert FLAGS.mode == 'train_and_eval'
            while current_step < FLAGS.train_steps:
                # Train for up to steps_per_eval number of steps.
                # At the end of training, a checkpoint will be written to --model_dir.
                next_checkpoint = min(current_step + FLAGS.steps_per_eval,
                                      FLAGS.train_steps)
                est.train(input_fn=imagenet_train.input_fn,
                          max_steps=next_checkpoint)
                current_step = next_checkpoint

                logging.info('Finished training up to step %d. Elapsed seconds %d.',
                             next_checkpoint, int(time.time() - start_timestamp))

                # Evaluate the model on the most recent model in --model_dir.
                # Since evaluation happens in batches of --eval_batch_size, some images
                # may be excluded modulo the batch size. As long as the batch size is
                # consistent, the evaluated images are also consistent.
                logging.info('Starting to evaluate.')
                eval_results = est.evaluate(
                    input_fn=imagenet_eval.input_fn,
                    steps=FLAGS.num_eval_images // FLAGS.eval_batch_size,
                    name=FLAGS.eval_name)
                logging.info('Eval results at step %d: %s',
                             next_checkpoint, eval_results)
                ckpt = tf.train.latest_checkpoint(FLAGS.model_dir)
                if FLAGS.archive_ckpt:
                    utils.archive_ckpt(
                        eval_results, eval_results['top_1_accuracy'], ckpt)

            elapsed_time = int(time.time() - start_timestamp)
            logging.info('Finished training up to step %d. Elapsed seconds %d.',
                         FLAGS.train_steps, elapsed_time)
    if FLAGS.export_dir:
        export(est, FLAGS.export_dir, input_image_size)

    #from help_modelarts import modelarts_result2obs
    #modelarts_result2obs(FLAGS)


if __name__ == '__main__':
    # These flags have no usable default; absl aborts flag parsing with a
    # usage error when any of them is missing from the command line.
    for required_flag in ("data_dir", "model_dir", "obs_dir", "model_name"):
        flags.mark_flag_as_required(required_flag)

    # Parse flags, then hand control to main().
    app.run(main)
